/*
 * AIEngine a new generation network intrusion detection system.
 *
 * Copyright (C) 2013-2023  Luis Campo Giralte
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Library General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Library General Public License for more details.
 *
 * You should have received a copy of the GNU Library General Public
 * License along with this library; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA  02110-1301, USA.
 *
 * Written by Luis Campo Giralte <luis.camp0.2009@gmail.com>
 *
 */
#include "SMTPProtocol.h"
#include <iomanip> // setw

namespace aiengine {

// List of supported client commands.
// Each tuple holds, in order (see processFlow/statistics/resetCounters):
//   0: command text compared against the start of a client payload
//   1: number of bytes of that text compared by memcmp
//   2: label used as the counter/statistics key
//   3: hit counter, incremented on every match and zeroed by resetCounters
//   4: SMTPCommandTypes value stored in the SMTPInfo via setCommand
// NOTE(review): the "expandss" label looks like a typo, but it is a
// published counter key, so changing it may break consumers — confirm first.
std::vector<SmtpCommandType> SMTPProtocol::commands_ {
        std::make_tuple("EHLO"      	,4,     "hellos"     	,0,	static_cast<int8_t>(SMTPCommandTypes::SMTP_CMD_EHLO)),
        std::make_tuple("AUTH LOGIN"  	,10,    "auth logins"  	,0,	static_cast<int8_t>(SMTPCommandTypes::SMTP_CMD_AUTH)),
        std::make_tuple("MAIL FROM:"    ,10,    "mail froms"	,0,	static_cast<int8_t>(SMTPCommandTypes::SMTP_CMD_MAIL)),
        std::make_tuple("RCPT TO:"      ,8,     "rcpt tos"      ,0,	static_cast<int8_t>(SMTPCommandTypes::SMTP_CMD_RCPT)),
        std::make_tuple("DATA"       	,4,     "datas"       	,0,	static_cast<int8_t>(SMTPCommandTypes::SMTP_CMD_DATA)),
        std::make_tuple("EXPN"         	,4,     "expandss"     	,0,	static_cast<int8_t>(SMTPCommandTypes::SMTP_CMD_EXPN)),
        std::make_tuple("VRFY"        	,4,     "verifys"       ,0,	static_cast<int8_t>(SMTPCommandTypes::SMTP_CMD_VRFY)),
        std::make_tuple("RSET"         	,4,     "resets"        ,0,	static_cast<int8_t>(SMTPCommandTypes::SMTP_CMD_RSET)),
        std::make_tuple("HELP"         	,4,     "helps"        	,0,	static_cast<int8_t>(SMTPCommandTypes::SMTP_CMD_HELP)),
        std::make_tuple("NOOP"         	,4,     "noops"        	,0,	static_cast<int8_t>(SMTPCommandTypes::SMTP_CMD_NOOP)),
        std::make_tuple("STARTTLS"    	,8,     "starttls"      ,0,	static_cast<int8_t>(SMTPCommandTypes::SMTP_CMD_STARTTLS)),
        std::make_tuple("QUIT"         	,4,     "quits"        	,0,	static_cast<int8_t>(SMTPCommandTypes::SMTP_CMD_QUIT))
};

SMTPProtocol::~SMTPProtocol() {

	// Explicitly drop our reference to the anomaly object on destruction.
	anomaly_ = nullptr;
}

bool SMTPProtocol::check(const Packet &packet) {

	// The first message comes from the server and have code 220
	const uint8_t *payload = packet.getPayload();

	if ((payload[0] == '2')and(payload[1] == '2')and(payload[2] == '0')and
		((packet.getSourcePort() == 25) or
		(packet.getSourcePort() == 2525) or
		(packet.getSourcePort() == 587))) {

		++total_valid_packets_;
		return true;
	} else {
		++total_invalid_packets_;
		return false;
	}
}

void SMTPProtocol::setDynamicAllocatedMemory(bool value) {

	// Apply the same allocation policy to every cache owned by the protocol.
	auto apply_policy = [value] (auto &cache) {
		cache->setDynamicAllocatedMemory(value);
	};

	apply_policy(info_cache_);
	apply_policy(from_cache_);
	apply_policy(to_cache_);
}

bool SMTPProtocol::isDynamicAllocatedMemory() const {

	// setDynamicAllocatedMemory applies one policy to all three caches,
	// so the SMTPInfo cache is used as the representative one.
	const bool dynamic = info_cache_->isDynamicAllocatedMemory();
	return dynamic;
}

uint64_t SMTPProtocol::getCurrentUseMemory() const {

	// Size of the protocol object plus the memory the caches report as in use.
	const uint64_t info_bytes = info_cache_->getCurrentUseMemory();
	const uint64_t from_bytes = from_cache_->getCurrentUseMemory();
	const uint64_t to_bytes = to_cache_->getCurrentUseMemory();

	return static_cast<uint64_t>(sizeof(SMTPProtocol)) + info_bytes + from_bytes + to_bytes;
}

uint64_t SMTPProtocol::getAllocatedMemory() const {

	// Size of the protocol object plus the memory allocated by its caches.
	const uint64_t info_bytes = info_cache_->getAllocatedMemory();
	const uint64_t from_bytes = from_cache_->getAllocatedMemory();
	const uint64_t to_bytes = to_cache_->getAllocatedMemory();

	return static_cast<uint64_t>(sizeof(SMTPProtocol)) + info_bytes + from_bytes + to_bytes;
}

uint64_t SMTPProtocol::getTotalAllocatedMemory() const {

	// Cache allocations plus the bytes held by the keys of the from/to maps.
	return getAllocatedMemory() + compute_memory_used_by_maps();
}

uint64_t SMTPProtocol::compute_memory_used_by_maps() const {

	// Sum the length of every key stored in the "from" and "to" maps.
	uint64_t total = 0;

	for (auto const &sender: from_map_)
		total += sender.first.size();

	for (auto const &recipient: to_map_)
		total += recipient.first.size();

	return total;
}

uint32_t SMTPProtocol::getTotalCacheMisses() const {

	// A miss is any failure reported by one of the three caches.
	const uint32_t info_fails = info_cache_->getTotalFails();
	const uint32_t from_fails = from_cache_->getTotalFails();
	const uint32_t to_fails = to_cache_->getTotalFails();

	return info_fails + from_fails + to_fails;
}

void SMTPProtocol::releaseCache() {

	// Detach every SMTPInfo from the flows of the attached flow manager,
	// return all cached strings and log how much memory was reclaimed.
	if (FlowManagerPtr fm = flow_mng_.lock(); fm) {
		auto ft = fm->getFlowTable();

		std::ostringstream msg;
		msg << "Releasing " << name() << " cache";

		infoMessage(msg.str());

		uint64_t total_cache_bytes_released = compute_memory_used_by_maps();
		uint64_t total_bytes_released_by_flows = 0;
		uint64_t total_cache_save_bytes = 0;
		uint32_t release_flows = 0;
		uint32_t release_froms = from_map_.size();
		uint32_t release_tos = to_map_.size();

		for (auto &flow: ft) {
			if (SharedPointer<SMTPInfo> info = flow->getSMTPInfo(); info) {
				// Account for the SMTPInfo object itself, not for the
				// smart pointer that references it.
				total_bytes_released_by_flows += sizeof(SMTPInfo);

				flow->layer7info.reset();
				++release_flows;
				info_cache_->release(info);
			}
		}

		// Some entries can still be on the maps and need to be returned
		// to their own caches. Every hit beyond the first reused a shared
		// string instead of storing another copy of it.
		for (auto &entry: from_map_) {
			total_cache_save_bytes += entry.second.sc->size() * (entry.second.hits - 1);
			releaseStringToCache(from_cache_, entry.second.sc);
		}
		from_map_.clear();

		for (auto &entry: to_map_) {
			total_cache_save_bytes += entry.second.sc->size() * (entry.second.hits - 1);
			releaseStringToCache(to_cache_, entry.second.sc);
		}
		to_map_.clear();

		data_time_ = boost::posix_time::microsec_clock::local_time();

		msg.str("");
		msg << "Release " << release_froms;
		msg << " froms, " << release_tos << " tos, " << release_flows << " flows";
		computeMemoryUtilization(msg, total_cache_bytes_released, total_bytes_released_by_flows, total_cache_save_bytes);
		infoMessage(msg.str());
	}
}

void SMTPProtocol::releaseFlowInfo(Flow *flow) {

	// Hand the SMTPInfo attached to the flow (if any) back to its cache.
	SharedPointer<SMTPInfo> attached = flow->getSMTPInfo();

	if (!attached)
		return;

	info_cache_->release(attached);
}

void SMTPProtocol::attach_from(SMTPInfo *info, const boost::string_ref &from) {

	// Nothing to do when the session already has a sender attached.
	if (info->from)
		return;

	// Known sender: share the existing string and count the reuse.
	if (auto known = from_map_.find(from); known != from_map_.end()) {
		auto &slot = known->second;
		++slot.hits;
		info->from = slot.sc;
		return;
	}

	// New sender: take a string from the cache (if one is available)
	// and index it by its content.
	SharedPointer<StringCache> fresh = from_cache_->acquire();
	if (!fresh)
		return;

	fresh->name(from.data(), from.length());
	info->from = fresh;
	from_map_.insert(std::make_pair(fresh->name(), fresh));
}

void SMTPProtocol::handle_cmd_mail(SMTPInfo *info, const boost::string_ref &header) {

	// Parse "MAIL FROM:" and extract the sender address and its domain.
	// Raises SMTP_BOGUS_HEADER / SMTP_LONG_EMAIL anomalies on malformed input.
	// The sender is checked against the banned and the plain domain managers.
	size_t start = strlen("MAIL FROM:");

	// The header must hold at least the command and the two line terminator
	// bytes; shorter input would make the unsigned arithmetic below wrap.
	if (header.length() < start + 2) {
		++total_events_;
		if (current_flow_->getPacketAnomaly() == PacketAnomalyType::NONE)
			current_flow_->setPacketAnomaly(PacketAnomalyType::SMTP_BOGUS_HEADER);

		anomaly_->incAnomaly(current_flow_, PacketAnomalyType::SMTP_BOGUS_HEADER);
		return;
	}

	// Assumes the line ends with CRLF; `end` indexes the '\r'.
	size_t end = header.length() - 2;

	if (end - start >= MaxSMTPEmailLength) {
		++total_events_;
		if (current_flow_->getPacketAnomaly() == PacketAnomalyType::NONE)
			current_flow_->setPacketAnomaly(PacketAnomalyType::SMTP_LONG_EMAIL);

		anomaly_->incAnomaly(current_flow_, PacketAnomalyType::SMTP_LONG_EMAIL);
		return;
	}

	// The character right after the command is always skipped by the substr
	// below (presumably '<' or a space); skip a '<' that follows it as well.
	if (header[start + 1] == '<')
		++start;

	// Drop a closing '>' placed just before the CRLF.
	if (header[end - 1] == '>')
		--end;

	boost::string_ref from(header.substr(start + 1, end - start - 1));

	size_t token = from.find_first_of("@");

	if (token == boost::string_ref::npos) {
		++total_events_;
		if (current_flow_->getPacketAnomaly() == PacketAnomalyType::NONE)
			current_flow_->setPacketAnomaly(PacketAnomalyType::SMTP_BOGUS_HEADER);

		anomaly_->incAnomaly(current_flow_, PacketAnomalyType::SMTP_BOGUS_HEADER);
		return;
	}
	boost::string_ref domain(from.substr(token + 1, from.size()));

	// Banned domains stop the inspection of the whole session.
	if (ban_domain_mng_) {
		if (auto dom_candidate = ban_domain_mng_->getDomainName(domain); dom_candidate) {
			++total_ban_domains_;
			info->setIsBanned(true);
			return;
		}
	}
	++total_allow_domains_;

	attach_from(info, from);

	if (domain_mng_) {
		if (auto dom_candidate = domain_mng_->getDomainName(domain); dom_candidate) {
			++total_events_;
			info->matched_domain_name = dom_candidate;
#if defined(BINDING)
			if (dom_candidate->call.haveCallback())
				dom_candidate->call.executeCallback(current_flow_);
#endif
		}
	}
}

void SMTPProtocol::handle_cmd_rcpt(SMTPInfo *info, const boost::string_ref &header) {

	// Parse "RCPT TO:" and attach the recipient address to the session.
	// A recipient is only recorded while none is attached yet.
	if (info->to)
		return;

	size_t start = strlen("RCPT TO:");

	// The header must hold at least the command and the two line terminator
	// bytes; shorter input would make the unsigned arithmetic below wrap.
	if (header.length() < start + 2) {
		++total_events_;
		if (current_flow_->getPacketAnomaly() == PacketAnomalyType::NONE)
			current_flow_->setPacketAnomaly(PacketAnomalyType::SMTP_BOGUS_HEADER);

		anomaly_->incAnomaly(current_flow_, PacketAnomalyType::SMTP_BOGUS_HEADER);
		return;
	}

	// Assumes the line ends with CRLF; `end` indexes the '\r'.
	size_t end = header.length() - 2;

	if (end - start >= MaxSMTPEmailLength) {
		++total_events_;
		if (current_flow_->getPacketAnomaly() == PacketAnomalyType::NONE)
			current_flow_->setPacketAnomaly(PacketAnomalyType::SMTP_LONG_EMAIL);

		anomaly_->incAnomaly(current_flow_, PacketAnomalyType::SMTP_LONG_EMAIL);
		return;
	}

	// The character right after the command is always skipped by the substr
	// below; skip a '<' that follows it as well.
	if (header[start + 1] == '<')
		++start;

	// Drop a closing '>' placed just before the CRLF.
	if (header[end - 1] == '>')
		--end;

	boost::string_ref to(header.substr(start + 1, end - start - 1));

	// Reuse the shared string when this recipient was already seen,
	// otherwise take a new one from the cache (if available).
	if (StringMap::iterator it = to_map_.find(to); it != to_map_.end()) {
		++(it->second).hits;
		info->to = (it->second).sc;
	} else if (SharedPointer<StringCache> to_ptr = to_cache_->acquire(); to_ptr) {
		to_ptr->name(to.data(), to.length());
		info->to = to_ptr;
		to_map_.insert(std::make_pair(to_ptr->name(), to_ptr));
	}
}

void SMTPProtocol::process_payloadl7(Flow * flow, SMTPInfo *info, const boost::string_ref &payloadl7) {

	// Only sessions whose sender domain matched a DomainName are evaluated.
	auto &matched = info->matched_domain_name;
	if (!matched)
		return;

	// Attach the domain's regex manager unless the flow already has one.
	if (matched->haveRegexManager() and !flow->regex_mng)
		flow->regex_mng = matched->getRegexManager();

	eval_.processFlowPayloadLayer7(flow, payloadl7);
}

void SMTPProtocol::processFlow(Flow *flow) {

	// Per-packet entry point.
	// Client packets are either email data (while in DATA mode) or commands.
	// Server packets are replies whose 3-digit code is inspected for STARTTLS.
	CPUCycle cycles(&total_cpu_cycles_);
	int length = flow->packet->getLength();
	const uint8_t *payload = flow->packet->getPayload();
	total_bytes_ += length;
	++total_packets_;

	setHeader(payload);

	SharedPointer<SMTPInfo> info = flow->getSMTPInfo();

	if (!info) {
		if (info = info_cache_->acquire(); !info) {
			logFailCache(info_cache_->name(), flow);
			return;
		}
		flow->layer7info = info;
	}

	// Banned sessions (see handle_cmd_mail) are not inspected any more.
	if (info->isBanned())
		return;

	current_flow_ = flow;

	if (flow->getFlowDirection() == FlowDirection::FORWARD) {

		if (info->isData()) { // The client is transferring the email
			boost::string_ref payloadl7(reinterpret_cast<const char*>(payload), length);

			info->incTotalDataBytes(length); /* Update the bytes */

			process_payloadl7(flow, info.get(), payloadl7);

			// Check if this is the last data block (CRLF CRLF "." CRLF).
			// offset == 0 is a packet carrying only the terminator.
			if (int offset = length - 7; offset >= 0) {
				if (std::memcmp(&payload[offset], "\x0d\x0a\x0d\x0a\x2e\x0d\x0a", 7) == 0) {
					info->incTotalDataBlocks();
					info->setIsData(false);
				}
			}
			return;
		}

		// Commands sent by the client
		for (auto &command: commands_) {
			const char *c = std::get<0>(command);
			int offset = std::get<1>(command);

			// A packet shorter than the command cannot match it, and
			// comparing would read past the end of the payload.
			if ((length < offset) or (std::memcmp(c, &header_[0], offset) != 0))
				continue;

			int32_t *hits = &std::get<3>(command);
			int8_t cmd = std::get<4>(command);

			++(*hits);
			++total_smtp_client_commands_;

			// Check if the commands are MAIL or RCPT
			if (cmd == static_cast<int8_t>(SMTPCommandTypes::SMTP_CMD_MAIL)) {
				boost::string_ref header(reinterpret_cast<const char*>(header_), length);
				handle_cmd_mail(info.get(), header);
			} else if (cmd == static_cast<int8_t>(SMTPCommandTypes::SMTP_CMD_RCPT)) {
				boost::string_ref header(reinterpret_cast<const char*>(header_), length);
				handle_cmd_rcpt(info.get(), header);
			} else if (cmd == static_cast<int8_t>(SMTPCommandTypes::SMTP_CMD_DATA)) {
				info->setIsData(true);
			} else if (cmd == static_cast<int8_t>(SMTPCommandTypes::SMTP_CMD_STARTTLS)) {
				info->setStartTLS(true);
				// Force to write on the databaseAdaptor update method
				flow->packet->setForceAdaptorWrite(true);
			}
			info->setCommand(cmd);
			return;
		}
		return;
	}

	// Responses from the server start with a three digit reply code;
	// never read it from a shorter payload.
	if (length < 3)
		return;

	try {
		const char *header = reinterpret_cast<const char*>(header_);
		std::string value(header, 3);

		int code = std::stoi(value);

		++total_smtp_server_responses_;

		// The server agrees to start a SSL session for this Flow
		if ((info->isStartTLS()) and (code == 220)) {
			// Release the attached SMTPInfo object
			releaseFlowInfo(flow);
			// Detach it from the flow too, otherwise the flow keeps referencing
			// an object that the cache may hand to another flow.
			flow->layer7info.reset();
			// Reset the number of l7 packets, check SSLProtocol.cc
			flow->total_packets_l7 = 0;
			// Reset the forwarder so the next time will be a SSL flow
			flow->forwarder.reset();
		}
	} catch(std::invalid_argument&) {
		// Not a numeric reply code; nothing to do with it.
	}
}

void SMTPProtocol::setDomainNameManager(const SharedPointer<DomainNameManager> &dm) {

	// Unplug the previously attached manager, if any.
	if (domain_mng_)
		domain_mng_->setPluggedToName("");

	// A null manager simply detaches the current one.
	if (!dm) {
		domain_mng_.reset();
		return;
	}

	domain_mng_ = dm;
	domain_mng_->setPluggedToName(name());
}

void SMTPProtocol::statistics(std::basic_ostream<char> &out, int level, int32_t limit) const {

	// Remember the stream formatting so it can be restored on exit.
	const std::ios_base::fmtflags saved_flags(out.flags());

	showStatisticsHeader(out, level);

	if (level > 0) {
		if (ban_domain_mng_)
			out << "\t" << "Plugged banned domains from:" << ban_domain_mng_->name() << std::endl;
		if (domain_mng_)
			out << "\t" << "Plugged domains from:" << domain_mng_->name() << std::endl;
	}

	const bool detailed = level > 3;

	if (detailed) {
		out << "\t" << "Total allow domains:    " << std::setw(10) << total_allow_domains_ << "\n";
		out << "\t" << "Total banned domains:   " << std::setw(10) << total_ban_domains_ << "\n";
		out << "\t" << "Total client commands:  " << std::setw(10) << total_smtp_client_commands_ << "\n";
		out << "\t" << "Total server responses: " << std::setw(10) << total_smtp_server_responses_ << std::endl;

		// One right-aligned line per command counter.
		for (auto &command: commands_) {
			const char *label = std::get<2>(command);
			const int32_t hits = std::get<3>(command);
			const int width = 27 - static_cast<int>(strlen(label));

			out << "\t" << "Total " << label << ":" << std::right << std::setfill(' ') << std::setw(width) << hits << std::endl;
		}
	}

	if (level > 5) {
		if (auto forwarder = flow_forwarder_.lock(); forwarder)
			forwarder->statistics(out);
	}

	if (detailed) {
		info_cache_->statistics(out);
		from_cache_->statistics(out);
		to_cache_->statistics(out);

		if (level > 4) {
			from_map_.show(out, "\t", limit);
			to_map_.show(out, "\t", limit);
		}
	}

	out.flags(saved_flags);
}

void SMTPProtocol::statistics(Json &out, int level) const {

	showStatisticsHeader(out, level);

	// Protocol counters are only exported at the detailed levels.
	if (level <= 3)
		return;

	out["allow"] = total_allow_domains_;
	out["banned"] = total_ban_domains_;
	out["requests"] = total_smtp_client_commands_;
	out["responses"] = total_smtp_server_responses_;

	// Per-command hit counters keyed by their label.
	Json per_command;
	for (auto &command: commands_)
		per_command.emplace(std::get<2>(command), std::get<3>(command));

	out["commands"] = per_command;
}

void SMTPProtocol::increaseAllocatedMemory(int value) {

	// Ask every cache to create `value` more items.
	auto grow = [value] (auto &cache) { cache->create(value); };

	grow(info_cache_);
	grow(from_cache_);
	grow(to_cache_);
}

void SMTPProtocol::decreaseAllocatedMemory(int value) {

	// Ask every cache to destroy `value` items.
	auto shrink = [value] (auto &cache) { cache->destroy(value); };

	shrink(info_cache_);
	shrink(from_cache_);
	shrink(to_cache_);
}

CounterMap SMTPProtocol::getCounters() const {

	CounterMap counters;

	// Global protocol counters first, then one entry per SMTP command.
	counters.addKeyValue("packets", total_packets_);
	counters.addKeyValue("bytes", total_bytes_);
	counters.addKeyValue("commands", total_smtp_client_commands_);
	counters.addKeyValue("responses", total_smtp_server_responses_);

	for (auto &command: commands_) {
		const char *label = std::get<2>(command);
		counters.addKeyValue(label, std::get<3>(command));
	}

	return counters;
}

// Script-binding helpers, compiled only for the Python or Ruby bindings.
#if defined(PYTHON_BINDING) || defined(RUBY_BINDING)
#if defined(PYTHON_BINDING)
boost::python::dict SMTPProtocol::getCacheData(const std::string &name) const {
#elif defined(RUBY_BINDING)
VALUE SMTPProtocol::getCacheData(const std::string &name) const {
#endif
	// Convert the "from" or "to" map (name matched case-insensitively)
	// with addMapToHash.
        if (boost::iequals(name, "from"))
		return addMapToHash(from_map_);
        else if (boost::iequals(name, "to"))
		return addMapToHash(to_map_);

	// Unknown name: convert a placeholder map instead.
	// NOTE(review): what `{"", ""}` constructs depends on StringMap's
	// constructors — confirm it produces an empty result.
	StringMap empty {"", ""};

        return addMapToHash(empty);
}

#if defined(PYTHON_BINDING)
// Return the "from" or "to" string cache (case-insensitive name),
// or nullptr for any other name.
SharedPointer<Cache<StringCache>> SMTPProtocol::getCache(const std::string &name) {

        if (boost::iequals(name, "from"))
                return from_cache_;
        else if (boost::iequals(name, "to"))
                return to_cache_;

        return nullptr;
}

#endif

#endif

void SMTPProtocol::statistics(Json &out, const std::string &map_name, int32_t limit) const {

	// Dump "key -> hits" for the map selected by name ("froms" or "tos",
	// case-insensitive); any other name leaves `out` untouched.
	// NOTE(review): `limit` is not applied here, unlike the stream overload
	// which forwards it to show() — confirm whether that is intended.
	auto dump = [&out] (const auto &map) {
		for (auto &item: map)
			out.emplace(item.first, item.second.hits);
	};

	if (boost::iequals(map_name, "froms"))
		dump(from_map_);
	else if (boost::iequals(map_name, "tos"))
		dump(to_map_);
}

void SMTPProtocol::resetCounters() {

	// Delegate to reset() first, then clear the SMTP specific counters.
	reset();

	total_events_ = 0;
	total_allow_domains_ = 0;
	total_ban_domains_ = 0;
	total_smtp_client_commands_ = 0;
	total_smtp_server_responses_ = 0;

	// Zero the per-command hit counters as well.
	for (auto &command: commands_) {
		auto &hits = std::get<3>(command);
		hits = 0;
	}
}

} // namespace aiengine

